{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "twelve-miracle",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/PyTorchLightning/lightning-flash/blob/master/flash_notebooks/tabular_classification.ipynb\" target=\"_parent\">\n",
    "    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "genuine-elephant",
   "metadata": {},
   "source": [
    "In this notebook, we'll go over the basics of lightning Flash by training a TabularClassifier on [Titanic Dataset](https://www.kaggle.com/c/titanic).\n",
    "\n",
    "---\n",
    "  - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n",
    "  - Check out [Flash documentation](https://lightning-flash.readthedocs.io/en/latest/)\n",
    "  - Check out [Lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n",
    "  - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "sorted-dancing",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caring-appreciation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%capture\n",
    "! pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[tabular]'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sexual-diabetes",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchmetrics.classification import Accuracy, Precision, Recall\n",
    "\n",
    "import flash\n",
    "from flash.core.data.utils import download_data\n",
    "from flash.tabular import TabularClassifier, TabularClassificationData"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "boxed-harvest",
   "metadata": {},
   "source": [
    "###  1. Download the data\n",
    "The data are downloaded from a URL, and save in a 'data' directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "backed-render",
   "metadata": {},
   "outputs": [],
   "source": [
    "download_data(\"https://pl-flash-data.s3.amazonaws.com/titanic.zip\", 'data/')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "young-arthritis",
   "metadata": {},
   "source": [
    "###  2. Load the data\n",
    "Flash Tasks have built-in DataModules that you can use to organize your data. Pass in a train, validation and test folders and Flash will take care of the rest.\n",
    "\n",
    "Creates a TabularData relies on [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ultimate-bunny",
   "metadata": {},
   "outputs": [],
   "source": [
    "datamodule = TabularClassificationData.from_csv(\n",
    "    [\"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Cabin\", \"Embarked\"],\n",
    "    [\"Fare\"],\n",
    "    target_fields=\"Survived\",\n",
    "    train_file=\"./data/titanic/titanic.csv\",\n",
    "    test_file=\"./data/titanic/test.csv\",\n",
    "    val_split=0.25,\n",
    "    batch_size=8,\n",
    ")\n"
   ]
  },
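  {
   "cell_type": "markdown",
   "id": "val-split-sketch",
   "metadata": {},
   "source": [
    "To make the loading step above more concrete, here is a minimal sketch in plain pandas of what holding out 25% of the rows for validation looks like. It is not Flash's internal implementation: the rows are made up, and the names `df`, `train_df` and `val_df` are illustrative only.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Made-up rows in the Titanic column layout (not real records from the dataset).\n",
    "df = pd.DataFrame({\n",
    "    \"Sex\": [\"male\", \"female\", \"female\", \"male\"],\n",
    "    \"Embarked\": [\"S\", \"C\", \"S\", \"Q\"],\n",
    "    \"Fare\": [7.25, 71.28, 7.92, 8.05],\n",
    "    \"Survived\": [0, 1, 1, 0],\n",
    "})\n",
    "\n",
    "# Hold out 25% of the rows for validation, similar in spirit to val_split=0.25.\n",
    "val_df = df.sample(frac=0.25, random_state=0)\n",
    "train_df = df.drop(val_df.index)\n",
    "\n",
    "print(len(train_df), len(val_df))\n",
    "```\n",
    "\n",
    "Flash performs its own split when `val_split` is given, so the exact rows it selects may differ from this sketch."
   ]
  },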
  {
   "cell_type": "markdown",
   "id": "brutal-hypothesis",
   "metadata": {},
   "source": [
    "###  3. Build the model\n",
    "\n",
    "Note: Categorical columns will be mapped to the embedding space. Embedding space is set of tensors to be trained associated to each categorical column. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "practical-perry",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TabularClassifier.from_data(datamodule)"
   ]
  },
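  {
   "cell_type": "markdown",
   "id": "embedding-sketch",
   "metadata": {},
   "source": [
    "The note on embeddings above can be illustrated with a minimal sketch in plain PyTorch. This is not how Flash builds its model internally: the column cardinalities, the embedding size and the names `cardinalities`, `embeddings` and `codes` are assumptions chosen for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Hypothetical number of distinct categories in two categorical columns.\n",
    "cardinalities = [2, 3]\n",
    "embedding_dim = 4  # illustrative size, not Flash's default\n",
    "\n",
    "# One trainable lookup table per categorical column.\n",
    "embeddings = nn.ModuleList(nn.Embedding(n, embedding_dim) for n in cardinalities)\n",
    "\n",
    "# A batch of 5 rows; each column holds integer category codes.\n",
    "codes = torch.tensor([[0, 2], [1, 0], [1, 1], [0, 0], [1, 2]])\n",
    "\n",
    "# Look up each column's codes and concatenate the resulting vectors per row.\n",
    "embedded = torch.cat([emb(codes[:, i]) for i, emb in enumerate(embeddings)], dim=1)\n",
    "print(embedded.shape)  # torch.Size([5, 8])\n",
    "```\n",
    "\n",
    "Each row ends up as one dense vector (here 2 columns x 4 dimensions), which a network can consume alongside the numerical columns."
   ]
  },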
  {
   "cell_type": "markdown",
   "id": "dietary-bowling",
   "metadata": {},
   "source": [
    "###  4. Create the trainer. Run 10 times on data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "integral-interface",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = flash.Trainer(max_epochs=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "liable-remains",
   "metadata": {},
   "source": [
    "###  5. Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "controversial-newcastle",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.fit(model, datamodule=datamodule)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fluid-franchise",
   "metadata": {},
   "source": [
    "###  6. Test model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "therapeutic-bidder",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.test(model, datamodule=datamodule)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "genuine-pilot",
   "metadata": {},
   "source": [
    "###  7. Save it!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "alien-stand",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.save_checkpoint(\"tabular_classification_model.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "conventional-travel",
   "metadata": {},
   "source": [
    "# Predicting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "coated-insulation",
   "metadata": {},
   "source": [
    "###  8. Load the model from a checkpoint\n",
    "\n",
    "`TabularClassifier.load_from_checkpoint` supports both url or local_path to a checkpoint. If provided with an url, the checkpoint will first be downloaded and laoded to re-create the model. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "alpine-drilling",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TabularClassifier.load_from_checkpoint(\n",
    "    \"https://flash-weights.s3.amazonaws.com/0.7.0/tabular_classification_model.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "painted-assistant",
   "metadata": {},
   "source": [
    "###  9. Generate predictions from a sheet file! Who would survive?\n",
    "\n",
    "`TabularClassifier.predict` support both DataFrame and path to `.csv` file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "located-cable",
   "metadata": {},
   "outputs": [],
   "source": [
    "datamodule = TabularClassificationData.from_csv(\n",
    "    predict_file=\"data/titanic/titanic.csv\",\n",
    "    parameters=datamodule.parameters,\n",
    "    batch_size=8,\n",
    ")\n",
    "predictions = trainer.predict(model, datamodule=datamodule)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "realistic-infection",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "classified-casino",
   "metadata": {},
   "source": [
    "<code style=\"color:#792ee5;\">\n",
    "    <h1> <strong> Congratulations - Time to Join the Community! </strong>  </h1>\n",
    "</code>\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!\n",
    "\n",
    "### Help us build Flash by adding support for new data-types and new tasks.\n",
    "Flash aims at becoming the first task hub, so anyone can get started to great amazing application using deep learning. \n",
    "If you are interested, please open a PR with your contributions !!! \n",
    "\n",
    "\n",
    "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n",
    "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n",
    "\n",
    "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)!\n",
    "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n",
    "\n",
    "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n",
    "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n",
    "\n",
    "* Please, star [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n",
    "\n",
    "### Contributions !\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n",
    "\n",
    "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "* You can also contribute your own notebooks with useful examples !\n",
    "\n",
    "### Great thanks from the entire Pytorch Lightning Team for your interest !\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/PyTorchLightning/lightning-flash/18c591747e40a0ad862d4f82943d209b8cc25358/docs/source/_static/images/logo.svg\" width=\"800\" height=\"200\" />"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
